import torchvision.models as models
import torch
import numpy as np
import torchvision.transforms as transforms
import os
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import pandas as pd
from skimage import transform as skTr
import scipy.io
from itertools import chain
def identity(x):
    """Return ``x`` unchanged (no-op transform)."""
    return x
import random
SEED = 2021
# Seed the torch (CPU and CUDA), Python and NumPy generators with the same
# value, in that order, so runs of this module start from a fixed RNG state.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed):
    _seed_fn(SEED)
del _seed_fn


class Class_Dataset:
    """CUB-200-2011 dataset exposing images, part locations, boxes and attributes.

    ``__getitem__`` returns ``(image, parts, box, class_label, att_label)``:

    * ``image`` -- float32 array of shape ``(3, size, size)``; grayscale inputs
      are tiled to three channels (colour inputs are assumed to be RGB).
    * ``parts`` -- float32 copy of the image's part rows (presumably
      ``(n_parts, 3)`` = x, y, visible) with columns 0 and 1 swapped and
      rescaled to the resized image; column 2 is kept as-is.
    * ``box`` -- float32 length-4 box with columns reordered as (1, 0, 3, 2)
      and rescaled, i.e. (y, x, h, w) if the file stores (x, y, w, h).
    * ``class_label`` -- int64 label, renumbered when ``renumber`` is set.
    * ``att_label`` -- float32 per-image attribute row from the attribute CSV.
    """

    # Public split name -> key of the index array inside the split .mat file.
    _SPLIT_KEYS = {'train': 'train_loc',
                   'train_val': 'trainval_loc',
                   'val': 'val_loc',
                   'test_seen': 'test_seen_loc',
                   'test_unseen': 'test_unseen_loc'}

    # Keyword -> part columns. An attribute whose name contains the keyword is
    # linked to those columns of attribute_part_label.
    _PART_NAME = {'back': [0],
                  'beak': [1],
                  'belly': [2],
                  'breast': [3],
                  'crown': [4],
                  'forehead': [5],
                  'eye': [6, 10],
                  'head': [4, 5, 6, 10],
                  'leg': [7, 11],
                  'wing': [8, 12],
                  'nape': [9],
                  'tail': [13],
                  'throat': [14],
                  'upperparts': [0, 3, 4, 5, 6, 10, 14],
                  'underparts': [2, 8, 12, 13],
                  'primary': [0, 2, 3, 4, 5, 6, 8, 9, 10, 12, 13, 14]}

    # Number of part columns in attribute_part_label.
    _N_PARTS = 15

    def __init__(self, split, transform=transforms.ToTensor(), size=224, renumber=False,
                 attribute_batchsize=1, attribute_shuffle=False, dataset_root_path='/ABC/DEF/',
                 n_episode=600, split_mat_path='./data/CUB/att_splits.mat',
                 image_attribute_labels_path='./data/CUB/image_attribute_labels.csv'):
        """Load CUB metadata and pre-compute one record per image of ``split``.

        Args:
            split: 'train', 'train_val', 'val', 'test_seen' or 'test_unseen'
                (indices taken from the split .mat file), or 'all'.
            transform: accepted for API compatibility; not used by this class.
                NOTE: the default instance is created once, at class definition.
            size: side length every image is resized to.
            renumber: if true, class ids are mapped to 0..K-1 over the classes
                present in ``split`` (see ``renumber_index``).
            attribute_batchsize: stored as ``self.attribute_batchsize``.
            attribute_shuffle: accepted but unused.
            dataset_root_path: directory containing ``CUB_200_2011``; joined by
                plain string concatenation.
            n_episode: stored as ``self.n_episode``.
            split_mat_path: split .mat file (default keeps the historical path).
            image_attribute_labels_path: per-image attribute CSV (default keeps
                the historical path).

        Raises:
            ValueError: if ``split`` is not one of the names above.
        """
        cub_dir = dataset_root_path + '/CUB_200_2011'

        # Shift class ids to 0-based (the file presumably stores 1-based ids).
        self.img_label = self._read_cub_table(cub_dir, '/image_class_labels.txt')[..., 1] - 1
        self.corresponding_path = self._read_cub_table(cub_dir, '/images.txt')[..., 1:]
        self.class_name = self._read_cub_table(cub_dir, '/classes.txt')[..., 1]

        images_dir = cub_dir + '/images/'
        self.image_path = np.array([images_dir + rel for rel in self.corresponding_path[:, 0]],
                                   dtype=object)
        # Divided by 100 (presumably percentages -> [0, 1]).
        self.attribute = self._read_cub_table(
            cub_dir, '/attributes/class_attribute_labels_continuous.txt') / 100.
        self.split_mat = scipy.io.loadmat(split_mat_path)
        self.attribute_name = self._read_cub_table(cub_dir, '/attributes.txt')[..., 1]
        self.object_xywh = self._read_cub_table(cub_dir, '/bounding_boxes.txt')[..., 1:]
        self.image_attribute_labels = pd.read_csv(image_attribute_labels_path,
                                                  header=None).values
        self.attribute_batchsize = attribute_batchsize

        self.size = self.image_size = size

        self.img_id_list = self._select_split_ids(split)

        self.n_episode = n_episode
        self.attr_list = list(range(len(self.attribute_name)))

        # np.unique already returns its result sorted.
        split_classes = np.unique(self.img_label[self.img_id_list])
        self.cl_list = split_classes.tolist()
        self.renumbered_label = split_classes
        self.class_idx = split_classes.copy()
        # Lookup table: original class id -> position in split_classes (-1 if absent).
        self.label_transform_list = np.full(len(self.class_name), -1, dtype=int)
        self.label_transform_list[split_classes] = np.arange(len(split_classes))
        self.renumber = renumber

        self.attribute_part_label = self._build_attribute_part_label(self.attribute_name)

        id_to_parts = self._read_parts(cub_dir + '/parts/part_locs.txt')
        # Order rows by ascending image id instead of relying on dict insertion
        # order (assumes ids are 1..N, matching the 0-based img_id_list indexing).
        self.id_to_parts = np.array([id_to_parts[img_id] for img_id in sorted(id_to_parts)])

        self.sub_meta = []
        for path, label, att_row, parts, box in zip(self.image_path[self.img_id_list],
                                                   self.img_label[self.img_id_list],
                                                   self.image_attribute_labels[self.img_id_list],
                                                   self.id_to_parts[self.img_id_list],
                                                   self.object_xywh[self.img_id_list]):
            self.sub_meta.append({'path': path, 'part': parts,
                                  'mask': self._legacy_part_mask(att_row, parts),
                                  'class': label, 'box': box, 'att_label': att_row})

    @staticmethod
    def _read_cub_table(cub_dir, rel_path):
        """Return the space-separated, header-less table at ``cub_dir + rel_path`` as an ndarray."""
        return pd.read_table(cub_dir + rel_path, sep=' ', header=None).values

    def _select_split_ids(self, split):
        """Return the 0-based image indices of ``split`` as a 1-D array."""
        if split == 'all':
            return np.arange(self.img_label.shape[0])
        if split not in self._SPLIT_KEYS:
            raise ValueError(
                "split must be one of {'train', 'train_val', 'val', 'test_seen', "
                "'test_unseen', 'all'}, got %r" % (split,))
        # Split arrays are column vectors of indices shifted by one.
        return (self.split_mat[self._SPLIT_KEYS[split]] - 1)[:, 0]

    @classmethod
    def _build_attribute_part_label(cls, attribute_name):
        """Return a 0/1 matrix of shape (n_attributes, _N_PARTS) linking attributes to parts.

        Names containing 'head' but not 'forehead' map to the 'head' parts,
        names containing 'forehead' map to the 'forehead' part only, and any
        other name maps to the union of parts whose keyword appears in it.
        """
        label = np.zeros([len(attribute_name), cls._N_PARTS])
        for attr_idx, attr_n in enumerate(attribute_name):
            if 'head' in attr_n and 'forehead' not in attr_n:
                part_ids = cls._PART_NAME['head']
            elif 'forehead' in attr_n:
                part_ids = cls._PART_NAME['forehead']
            else:
                part_ids = [pid for keyword, ids in cls._PART_NAME.items()
                            if keyword in attr_n for pid in ids]
            for pid in part_ids:
                label[attr_idx, pid] = 1
        return label

    @staticmethod
    def _read_parts(filename):
        """Parse a part-location file into ``{img_id: [[x, y, visible], ...]}``.

        Each line is ``img_id part_id x y visible``; a line with ``part_id == 1``
        starts a new list for its image, so part 1 must precede that image's
        other parts. Blank lines are ignored.
        """
        id_to_parts = {}
        with open(filename, 'r') as fin:
            for line in fin:
                fields = line.strip().split(' ')
                if not fields[0]:
                    continue
                img_id, part_id = int(fields[0]), int(fields[1])
                entry = [float(fields[2]), float(fields[3]), int(fields[4])]
                if part_id == 1:
                    id_to_parts[img_id] = [entry]
                else:
                    id_to_parts[img_id].append(entry)
        return id_to_parts

    def _legacy_part_mask(self, att_row, parts):
        """Compute the scalar stored under the ``'mask'`` key of each sub_meta record.

        For the last attribute equal to 1 in ``att_row`` that is linked to at
        least one part, take its first linked part index ``p`` and return
        ``sum(attribute_part_label[p]) * parts[p][2]``. Returns 0.0 when no
        positive attribute is linked to a part. The value is not read anywhere
        in this class.
        """
        mask = 0.0
        for attr_idx in np.where(att_row == 1)[0]:
            linked = np.where(self.attribute_part_label[attr_idx] == 1)[0]
            if linked.size == 0:
                # Attribute matched no part keyword; skip rather than IndexError.
                continue
            first_part = linked[0]
            # NOTE(review): row `first_part` of attribute_part_label is an
            # *attribute* row selected with a *part* index; kept unchanged --
            # confirm the intended semantics before relying on 'mask'.
            mask = np.sum(self.attribute_part_label[first_part], 0) * parts[first_part][2]
        return mask

    def renumber_index(self, label):
        """Map ``label`` through ``label_transform_list`` when renumbering is enabled."""
        if self.renumber:
            return self.label_transform_list[label]
        return label

    def __getitem__(self, idx):
        """Return ``(image, parts, box, class_label, att_label)`` for record ``idx``.

        Part and box coordinates are rescaled from the original image size to
        ``self.image_size``; columns are reordered as described on the class.
        """
        meta = self.sub_meta[idx]

        raw = plt.imread(meta['path'])
        # plt.imread returns integer arrays for e.g. JPEG but floats already in
        # [0, 1] for PNG; only integer data needs the /255 rescale.
        if np.issubdtype(raw.dtype, np.integer):
            image = raw / 255.
        else:
            image = raw.astype(np.float64)
        if image.ndim != 3:
            # Grayscale: replicate into three channels.
            image = np.tile(image[..., None], (1, 1, 3))

        orig_h, orig_w = image.shape[:2]
        image = skTr.resize(image, (self.image_size, self.image_size))
        scale_h = self.image_size / orig_h
        scale_w = self.image_size / orig_w

        parts_src = np.array(meta['part'])
        parts = np.copy(parts_src)
        parts[..., 0] = parts_src[..., 1] * scale_h
        parts[..., 1] = parts_src[..., 0] * scale_w

        # Explicit float buffer so an integer-typed box is not truncated.
        box_src = np.asarray(meta['box'], dtype=np.float64)
        box = np.empty(box_src.shape, dtype=np.float64)
        box[..., 0] = box_src[..., 1] * scale_h
        box[..., 1] = box_src[..., 0] * scale_w
        box[..., 2] = box_src[..., 3] * scale_h
        box[..., 3] = box_src[..., 2] * scale_w

        class_label = self.renumber_index(np.array(meta['class'])).astype(np.int64)
        att_label = meta['att_label'].astype(np.float32)
        return (image.astype(np.float32).transpose([2, 0, 1]),
                parts.astype(np.float32),
                box.astype(np.float32),
                class_label,
                att_label)

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.img_id_list)

    def get_class_attribute(self, return_tensor=False):
        """Return a float32 copy of the class-attribute matrix.

        When ``renumber`` is set, only the rows in ``self.renumbered_label``
        (the split's classes, in renumbered order) are returned; otherwise all
        classes are returned. If ``return_tensor`` is true, the result is a
        torch tensor; otherwise it is an ndarray.
        """
        rows = self.attribute[self.renumbered_label] if self.renumber else self.attribute
        rows = rows.astype(np.float32)
        return torch.from_numpy(rows) if return_tensor else rows

